{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensorflow神经网络模型训练与冻结"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from keras.utils import to_categorical # 用于转换label为one-hot矩阵\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入数据集\n",
    "\n",
    "XO数据集是从EMNIST数据集经过预处理得到的，只筛选XO相关的数据,并且将向量转换为了2D图像。\n",
    "直接导入序列化后的数据集二进制文件。\n",
    "\n",
    "\n",
    "训练集里面, 字符`O`跟字符`X`各4800个样本。\n",
    "\n",
    "测试集里面, 字符`O`跟字符`X`各800个样本。\n",
    "\n",
    "他们的顺序是打乱的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "dataset = None\n",
    "\n",
    "with open('../../common/xo_dataset.bin', 'rb') as f:\n",
    "    dataset = pickle.load(f)\n",
    "    \n",
    "# 训练集图像向量\n",
    "train_images = dataset['train_images']\n",
    "# 训练集标签\n",
    "train_labels = dataset['train_labels']\n",
    "# 测试集图像向量\n",
    "test_images = dataset['test_images']\n",
    "# 测试集标签\n",
    "test_labels = dataset['test_labels']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 样本图片\n",
    "每个样本图片的尺寸是28x28的灰度图\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(28, 28)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_images[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fa817955748>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAENBJREFUeJzt3WuMlFWex/HfX9BWQa6T7SCD2+NETQTiJUSJdsTV1XgN6AuCLwyTmQzzYjBOsi/WuIlrXDcZNzuzmVeTMNEMbGbFTQA1akZmySqsMUQwLNcdRG0dkMsIAg2R+39f9MNui/38T1G3p5rz/SSkq+rfp+pQ3b9+quo85xxzdwHIz0VVdwBANQg/kCnCD2SK8AOZIvxApgg/kCnCD2SK8AOZIvxApka288HMjNMJgRZzd6vl+xo68pvZfWb2RzPbYWZPNXJfANrL6j2338xGSNou6R5JOyV9IOkxd98atOHID7RYO478t0ja4e6fuPsJSUslzW7g/gC0USPhnyzpT4Ou7yxu+wYzW2Bm68xsXQOPBaDJWv6Bn7svkrRI4mU/0EkaOfLvkjRl0PXvFrcBGAYaCf8Hkq4xs++Z2SWS5kl6vTndAtBqdb/sd/dTZrZQ0tuSRkh6yd23NK1nAFqq7qG+uh6M9/xAy7XlJB8AwxfhBzJF+IFMEX4gU4QfyBThBzLV1vn8ubroovhvbKp+oTpz5kxDdTQmz986AIQfyBXhBzJF+IFMEX4gU4QfyBRDfTWKhuO6urrCtqNHjw7rY8eOratPw8GpU6dKa0ePHg3bfvXVV3XfN9I48gOZIvxApgg/kCnCD2SK8AOZIvxApgg/kCnG+QsTJkwI69OnTy+tPfjgg2Hbq6++OqxPnTo1rA/nKb/9/f2ltb6+vrDt6tWrw/rSpUvD+oEDB0prnCPAkR/IFuEHMkX4gUwRfiBThB/IFOEHMkX4gUw1NM5vZn2S+iWdlnTK3Wc0o1NVGDNmTFi/4YYbSmtz5swJ26bm648bNy6sm9W06WpHOnnyZGltypQpYduJEyeG9fXr14f1HTt2lNb27dsXtm3n7tVVacZJPn/l7l824X4AtBEv+4FMNRp+l7TSzNab2YJmdAhAezT6sr/X3XeZ2V9I+oOZ/Y+7f+OE7OKPAn8YgA7T0JHf3XcVX/dJWiHpliG+Z5G7zxjOHwYCF6K6w29mo8zsirOXJd0raXOzOgagtRp52d8taUUxDDVS0r+5+++b0isALWftHM80s8oGTy+55JKwPm/evLD+6KOPltYeeuihsG1qPn5qHL+Rn9Hp06dbdt+1iP7vqeclOkdAkrZu3RrWN2zYUFp7/vnnw7ZffPFFWP/666/DepXcvaYTQxjqAzJF+IFMEX4gU4QfyBThBzJF+IFMZbN0d2ra7MMPPxzWZ86cWVprdGntY8eOhfW9e/fW3f7tt98O20bLW9ci9X+PliXv6ekJ21511VVhPVpOXZKuvfba0trIkfGv/ltvvRXWly9fHtaPHz8e1jsBR34gU4QfyBThBzJF+IFMEX4gU4QfyBThBzLFOH9h/PjxYT01JTiSmjabmj66cuXKsB5tN71s2bKGHjtlxIgRYf36668vraW2Lr/jjjvC+l133RXWu7q6Smu33npr2Da1hfeqVavCempp8E7AkR/IFOEHMkX4gUwRfiBThB/IFOEHMkX4gUxdMEt3p+ZnP/HEE2H9ueeeC+uXX355ae3EiRNh29RY+jPPPBPW33jjjbAeLf199OjRsG1qeexGNbJ0d2o+/+OPPx7Wb7rpptLa3XffHbY9cuRIWF+4cGFYf/XVV8N6akn1RrB0N4AQ4QcyRfiBTBF+IFOEH8gU4QcyRfiBTCXn85vZS5IekrTP3acVt02Q9IqkHkl9kua6+1et62ZaapvrMWPGhPXUfP3o/g8dOhS23bhxY1hfu3ZtWO/v7w/r0bka7TyPYyhnzpypqyZJu3btCuuvvfZa3e1T8/lHjx4d1qdNmxbW33zzzbDeynH+WtVy5P+tpPvOue0pSavc/RpJq4rrAIaRZPjdfbWkc7d1mS1pcXF5saQ5Te4XgBar9z1/t7vvLi7vkdTdpP4AaJOG1/Bzd4/O2TezBZIWNPo4AJqr3iP/XjObJEnF19LVCt19kbvPcPcZdT4WgBaoN/yvS5pfXJ4vKf7YFUDHSYbfzF6W9L6k68xsp5n9SNLPJd1jZh9J+uviOoBhJPme390fKynFE6JbIJr/nVqXP7WXe2r9+Whcds2aNWHbpUuXhvVPP/00rKfGwy9UqT3uN23aFNajdRTuu+/c0etv6u3tDeuPPPJIWF+xYkVY3759e2nt2LFjYdtm4Qw/IFOEH8gU4QcyRfiBTBF+IFOEH8jUsNqiOxqOS22xnZqCmXLgwLlzm/7fK6+8ErZNDQV2wvTO4Sj1vEU/s/fffz9s29PTE9avu+66sJ5aGvzw4cOltb6+vrBts3DkBzJF+IFMEX4gU4QfyBThBzJF+IFMEX4gU8NqnD/ahju11HJqae7UtNloq+vUlNzU0ttojehn+sknn4RtU2PtU6dODeup38fUlvLtwJEfyBThBzJF+IFMEX4gU4QfyBThBzJF+IFMVT/YeB66u8u3BJw1a1bdbSXp5MmTYT2aG75///6G7hutEY3zb926NWy7ZcuWsH7//feH9WiZ+U7R+T0E0BKEH8gU4QcyRfiBTBF+IFOEH8gU4QcylRznN7OXJD0kaZ+7Tytue1bSjyX9ufi2p939rVZ18qxG5vOn5k8fPHgwrEfzu6O5/lK+W2x3stSa/zn8zGo58v9W0lCbmf+Lu99Y/Gt58AE0VzL87r5aUvnpbQCGpUbe8y80s41m9pKZxXtlAeg49Yb/15K+L+lGSbsl/aLsG81sgZmtM7N1dT4WgBaoK/zuvtfdT7v7GUm/kXRL8L2L3H2Gu8+ot5MAmq+u8JvZpEFXH5G0uTndAdAutQz1vSzpTknfMbOdkv5e0p1mdqMkl9Qn6Sct7COAFkiG390fG+LmF1vQl4ak5k+bWVg/dOhQWN+2bVtpLdprXZLcPayjNaKf+bhx48K2Y8aMqfu+hwvO8AMyRfiBTBF+IFOEH8gU4QcyRfiBTA2rpbtPnTpVWksNt6WWz54wYUJYv+2220prqWWcN27cGNY///zzsB79vztdNATb6PLWXV1dYX3SpEmltSeffDJsG/28pXTfh8OUYI78QKYIP5Apwg9kivADmSL8QKYIP5Apwg9kaliN80dLZEdLa0tSf39/WB8/Pl6GcPr06aW12bNnh23Hjh0b1tesWRPWd+/eHdaj8wBaPZ04NbU1mjo7atSosG1qufXUtNxp06aV1mbOnBm2jc4RkNJLfx85ciSsd8K5Gxz5gUwRfiBThB/IFOEHMkX4gUwRfiBThB/IlLVzWWkza+jBRowYUVqbOHFi2PaFF14I63Pnzg3rl156aWntxIkTYdvUOQap+fzvvPNOWI/Of2j1vPLUvPapU6eW1np6esK2V1xxRVi/+OKLw3q0bXvq9yUltUbD/Pnzw/r27dtLa8eOHaurT2e5e03rinPkBzJF+IFMEX4gU4QfyBThBzJF+IFMEX4gU8n5/GY2RdISSd2SXNIid/+VmU2Q9IqkHkl9kua6+1et62o8hzq1xfbKlSvDem9vb1i/8sorS2uXXXZZ2DY1Hp2a157aUyC1J0GVojn30Ti8lH7eUqJzEFLnP6TG2jds2BDW9+zZE9Y74WdWy5H/lKS/cffrJc2U9FMzu17SU5JWufs1klYV1wEME8nwu/tud/+wuNwvaZukyZJmS1pcfNtiSXNa1UkAzXde7/nNrEfSTZLWSup297PrS+3RwNsCAMNEzWv4mdloScsk/czdDw9eu83dvey8fTNbIGlBox0F0Fw1HfnN7GINBP937r68uHmvmU0q6pMk7RuqrbsvcvcZ7j6jGR0G0BzJ8NvAIf5FSdvc/ZeDSq9LOjt1ab6k15rfPQCtkpzSa2a9ktZI2iTp7PjI0xp43//vkq6S9JkGhvoOJO6rffOHz5EajpszJ/688t577y2tpbbobnRqajSVWUovn92pUsNtqXpquCyaSp1aLn3Tpk1hfcmSJWH9s88+C+utnEpf65Te5Ht+d/8vSWV3dvf5dApA5+AMPyBThB/IFOEHMkX4gUwRfiBThB/I1LBauruVurq6wnq0zfbtt98etk0tUT158uSwfuedd4b1aEpwamntlNQ5CKnpxtHW6R9//HHYdvPmzWH94MGDdT/2e++9F7Y9fPhwWG90ee1WYuluACHCD2SK8AOZIvxApgg/kCnCD2SK8AOZYpy/RtGc+tRYd6NLc8+aNSusR0tgp8b5U/VU32+++eaw/u6775bWtmzZErZNzak/fvx4WI+2Lj9wIFx6IlwmvtMxzg8gRPiBTBF+IFOEH8gU4QcyRfiBTBF+IFOM8w8DI0fWvKvaeUut+Z967NQ5Cvv37y+tpdbdH85j7VVinB9AiPADmSL8QKYIP5Apwg9kivADmSL8QKaS4/xmNkXSEkndklzSInf/lZk9K+nHkv5cfOvT7v5W4r4Y57/AROscSIzVV6HWcf5awj9J0iR3/9DMrpC0XtIcSXMlHXH3f661U4T/wkP4O0+t4U+eOubuuyXtLi73m9k2SfEWMwA63nm95zezHkk3SVpb3LTQzDaa2UtmNr6kzQIzW2dm6xrqKYCmqvncfjMbLeldSf/o7svNrFvSlxr4HOAfNPDW4IeJ++Bl/wWGl/2dp2nv+SXJzC6W9Iakt939l0PUeyS94e7TEvdD+C8whL/zNG1ijw1M+3pR0rbBwS8+CDzrEUnxlqoAOkotn/b3SlojaZOkM8XNT0t6TNKNGnjZ3yfpJ8WHg9F9ceQHWqypL/ubhfADrcd8fgAhwg9kivADmSL8QKYIP5Apwg9kivADmSL8QKYIP5Apwg9kivADmSL8QKYIP5Apwg9kqnV7Pw/tS0mfDbr+neK2TtSpfevUfkn0rV7N7Ntf1vqNbZ3P/60HN1vn7jMq60CgU/vWqf2S6Fu9quobL/uBTBF+IFNVh39RxY8f6dS+dWq/JPpWr0r6Vul7fgDVqfrID6AilYTfzO4zsz+a2Q4ze6qKPpQxsz4z22RmG6reYqzYBm2fmW0edNsEM/uDmX1UfB1ym7SK+vasme0qnrsNZvZARX2bYmb/aWZbzWyLmT1Z3F7pcxf0q5Lnre0v+81shKTtku6RtFPSB5Iec/etbe1ICTPrkzTD3SsfEzazOyQdkbTk7G5IZvZPkg64+8+LP5zj3f1vO6Rvz+o8d25uUd/Kdpb+gSp87pq543UzVHHkv0XSDnf/xN1PSFoqaXYF/eh47r5a0oFzbp4taXFxebEGfnnarqRvHcHdd7v7h8Xlfklnd5au9LkL+lWJKsI/WdKfBl3fqc7a8tslrTSz9Wa2oOrODKF70M5IeyR1V9mZISR3bm6nc3aW7pjnrp4dr5uND/y+rdfdb5Z0v6SfFi9vO5IPvGfrpOGaX0v6vga2cdst6RdVdqbYWXqZpJ+5++HBtSqfuyH6VcnzVkX4d0maMuj6d4vbOoK77yq+7pO0QgNvUzrJ3rObpBZf91Xcn//j7nvd/bS7n5H0G1X43BU7Sy+T9Dt3X17cXPlzN1S/qnreqgj/B5KuMbPvmdklkuZJer2CfnyLmY0qPoiRmY2SdK86b/fh1yXNLy7Pl/RahX35hk7ZublsZ2lV/Nx13I7X7t72f5Ie0MAn/h9L+rsq+lDSr6sl/Xfxb0vVfZP0sgZeBp7UwGcjP5I0UdIqSR9J+g9JEzqob/+qgd2cN2ogaJMq6luvBl7Sb5S0ofj3QNXPXdCvSp43zvADMsUHfkCmCD+QKcIPZIrwA5ki/ECmCD+QKcIPZIrwA5n6X2CltpIE5lVcAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#　字符O的样例图片\n",
    "plt.imshow(train_images[0], cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fa8178ea978>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEFdJREFUeJzt3X+MVeWdx/HP10EBkYAoOxKKC/Kj2EDWmlFXFg0bpYoxqMSQYtywiSn9Q4xNMFmjQdE/iNnYksaYJjSSorDWJhYl2LgF3Kg12IDERZRtcXEQyCC/RECEOvDdP+bgjjrnecZ7z7nnDs/7lZC593zvc8/DZT6ce+9zzvOYuwtAes6pugMAqkH4gUQRfiBRhB9IFOEHEkX4gUQRfiBRhB9IFOEHEtWvkTszM04nBErm7tabx9V15Dezm83sL2b2oZk9WM9zAWgsq/XcfjNrkfRXSdMl7Za0UdIcd/8g0IYjP1CyRhz5r5b0obvvcPe/SfqtpNvqeD4ADVRP+EdK2tXt/u5s29eY2Twz22Rmm+rYF4CClf6Fn7svlbRU4m0/0EzqOfLvkTSq2/3vZdsA9AH1hH+jpPFmNsbMzpP0Y0mri+kWgLLV/Lbf3TvNbL6k/5TUImmZu79fWM8AlKrmob6adsZnfqB0DTnJB0DfRfiBRBF+IFGEH0gU4QcSRfiBRDX0ev4ynXfeecH6+eefH6wfPXo0WD916tR37hPQzDjyA4ki/ECiCD+QKMIPJIrwA4ki/ECi+tRQX0tLS27tpptuCradOXNmsL5mzZpgfe3atbm148ePB9uiGuecU/ux7fTp0wX2pDlx5AcSRfiBRBF+IFGEH0gU4QcSRfiBRBF+IFF9avbe4cOH59aefvrpYNvYOP/u3buD9eeeey63tnz58mDbnTt3BuuN/Dc4m/Tv3z9Yv/HGG2t+7nXr1gXrJ0+erPm5y8bsvQCCCD+QKMIPJIrwA4ki/ECiCD+QKMIPJKqu6/nNrF3SUUmnJHW6e1sRncozZMiQ3NrkyZODbWNTe1922WXB+gMPPJBbGzduXLDtk08+Gax3dHQE6wcPHgzWz9ZpxQcMGBCsT5gwIVh//PHHa9537LyPLVu2BOt94dyNIibz+Gd3P1DA8wBoIN72A4mqN/wu6Y9m9o6ZzSuiQwAao963/VPdfY+Z/Z2ktWb2P+7+RvcHZP8p8B8D0GTqOvK7+57s5z5JqyRd3cNjlrp7W9lfBgL4bmoOv5kNMrPBZ25L+pGkrUV1DEC56nnb3ypplZmdeZ7/cPdXC+kVgNLVHH533yHpHwrsS1RnZ2duLbbEdmzcNftPLNegQYNya3feeWewbewchO3btwfrK1asCNbP1jUFYuP4d9xxR7A+ceLE3FpsTv/QeR2StHDhwmC9vb09WG8GDPUBiSL8QKIIP5Aowg8kivADiSL8QKL61NTdoamaZ82aFWwbu7wzdklvbCiwHrFLcmNTf9czrXiVQ1L9+oVHmu+///5gfc6cOcH6lVde+Z37dMbnn38erD/yyCPB+lNPPRWsh4at68XU3QCCCD+QKMIPJIrwA4ki/ECiCD+QKMIPJKqI2XsbJrQs8ksvvRRsG5u6+9FHHw3WL7300txa7PLQmJaWlmA9dg7CggULcmuxsfTYeQC7du0K1utZqnrYsGHB+rXXXhusjxo1quZ9x8R+Xy644ILS9t0oHPmBRBF+IFGEH0gU4QcSRfiBRBF+IFGEH0hUnxrnDzl9+nSwHlty+aOPPgrWQ8uDDx06NNi2zLkAJGngwIG5tSlTpgTbHj58OFhftWpVsB6bayD07zJixIhg29jS54MHDw7W6xG73v7YsWOl7btROPIDiSL8QKIIP5Aowg8kivADiSL8QKIIP5Co6Lz9ZrZM0q2S9rn7pGzbMEkvSBotqV3SbHf/NLqzOuftDxk5cmSwvmTJkmD9iy++CNY3b96cW7v33nuDbUePHh2sx+YDqGe+gNj5D7G/97p164L1lStXBusff/xxbu2+++4Lto0tfR5ax6FeO3bsCNZnz54drId+X8pW5Lz9v5F08ze2PShpvbuPl7Q+uw+gD4mG393fkHToG5tvk3RmCpjlkm4vuF8ASlbr+8lWd+/Ibu+V1FpQfwA0SN3n9ru7hz7Lm9k8SfPq3Q+AYtV65P/EzEZIUvZzX94D3X2pu7e5e1uN+wJQglrDv1rS3Oz2XEkvF9MdAI0SDb+ZPS9pg6Tvm9luM7tH0hOSppvZdkk3ZvcB9CHRz/zunrcI+g0F96UuBw8eDNY3btwYrMeuz16xYkVuLTYmfP311wfrY8aMCdanTp0arA8fPjy3FjtHYNCgQcH6rbfeGqzH5gsIrXMfu56/zHH82HoDb775ZrAeOn+hr+AMPyBRhB9IFOEHEkX4gUQRfiBRhB9IVPSS3kJ3VuIlvTGx5aBjDh365rVN/y82nBYbsopNQT1jxoxg/eGHH86txYYRY8uDN7PYlOih3+0tW7YE286dOzdYj7VvZK562Hdhl/QCOAsRfiBRhB9IFOEHEkX4gUQRfiBRhB9I1FmzRHdMaJy+XvVOjx2rv/DCC8F6aJntu+++O9j2uuuuC9ZDlwtL8bH2spcnDzl16lRu7bXXXgu23b59e7Be5Th+UTjyA4ki/ECiCD+QKMIPJIrwA4ki/ECiCD+QqGTG+fuyEydOBOtr1qzJrW3YsCHYdubMmcH64sWLg/WLL744WK9S6ByDcePGBdtOmDAhWN+2bVuwHpsavBlw5AcSRfiBRBF+IFGEH0gU4QcSRfiBRBF+IFHRefvNbJmkWyXtc/dJ2bZFkn4iaX/2sIfc/Q/RnVU4bz96NnHixGB99erVwXpsvLzK6/lDv9uxcfjYOP5jjz0WrL/yyivBemdnZ7BejyLn7f+NpJt72L7E3a/I/kSDD6C5RMPv7m9IKm8aHACVqOcz/3wz22Jmy8zswsJ6BKAhag3/rySNlXSFpA5JP897oJnNM7NNZrapxn0BKEFN4Xf3T9z9lLuflvRrSVcHHrvU3dvcva3WTgIoXk3hN7MR3e7eIWlrMd0B0CjRS3rN7HlJ0yRdbGa7JT0qaZqZXSHJJbVL+mmJfQRQgmj43X1OD5ufKaEvKME554Tf3F1++eXB+pAhQ+rafz3z24fm3Zfi5xCE/u79+/cPtp08eXKwPn/+/GD9yJEjwfrrr7+eW4utA1EUzvADEkX4gUQRfiBRhB9IFOEHEkX4gUQxdfdZLjakNWXKlGD9wgvLu2zjwIEDwXrsstjYtOE33HBDbm3AgAHBti0tLcF6bGnzGTNmBOtvv/12bi22ZHtROPIDiSL8QKIIP5Aowg8kivADiSL8QKIIP5AoxvnPcuPHjw/Wp0+fHqz361ffr0jostzQZa2StGjRomA9dg7DwoULc2uzZs0Kth04cGCwfu655wbr06ZNC9ZbW1tza+3t7cG2ReHIDySK8AOJIvxAogg/kCjCDySK8AOJIvxAohjnPwuEpqgeO3ZssO0ll1xSdHe+5ssvv8ytbd0aXutl//79wXpsWvBXX301t3bNNdcE244ZMyZYj2nU9Nv14MgPJIrwA4ki/ECiCD+QKMIPJIrwA4ki/ECiouP8ZjZK0rOSWiW5pKXu/kszGybpBUmjJbVLmu3un5bXVeQJjfNPmjQp2LbMefklae/evbm1NWvWBNseP368rn2/+OKLubVjx44F295111117XvlypXBekdHR13PX4TeHPk7JS1w9x9I+kdJ95rZDyQ9KGm9u4+XtD67D6CPiIbf3TvcfXN2+6ikbZJGSrpN0vLsYcsl3V5WJwEU7zt95jez0ZJ+KOnPklrd/cx7l73q+lgAoI/o9bn9ZnaBpBcl/czdj5jZVzV3dzPr8URrM5snaV69HQVQrF4d+c3sXHUFf6W7/z7b/ImZjcjqIyTt66mtuy919zZ3byuiwwCKEQ2/dR3in5G0zd1/0a20WtLc7PZcSS8X3z0AZenN2/5/kvQvkt4zs3ezbQ9JekLS78zsHkk7Jc0up4uoR2gYUJK6f3wrw8mTJ3Nrn332Wan7PnHiRG4tNsz41ltv1bXvTz8Nj3qHpjRvlGj43f1PkvJ+Q/IXQAfQ1DjDD0gU4QcSRfiBRBF+IFGEH0gU4QcSxdTdqEtoLF2SNmzYkFs7fPhw0d3ptdg4+4EDBxrUk+pw5AcSRfiBRBF+IFGEH0gU4QcSRfiBRBF+IFGM85/ljhw5EqyHrreX4vMBrF27NlhfvHhxbu3gwYPBtigXR34gUYQfSBThBxJF+IFEEX4gUYQfSBThBxLFOP9ZoLOzM7e2atWqYNvJkycH61dddVWwvn79+mC9vb09t+be4wpvaBCO/ECiCD+QKMIPJIrwA4ki/ECiCD+QKMIPJMpiY61mNkrSs5JaJbmkpe7+SzNbJOknkvZnD33I3f8QeS4GdhvMLG919S4XXXRRsD506NBg/dChQ3XVUTx3D/+jZ3pzkk+npAXuvtnMBkt6x8zOzOCwxN2frLWTAKoTDb+7d0jqyG4fNbNtkkaW3TEA5fpOn/nNbLSkH0r6c7ZpvpltMbNlZnZhTpt5ZrbJzDbV1VMAhep1+M3sAkkvSvqZux+R9CtJYyVdoa53Bj/vqZ27L3X3NndvK6C/AArSq/Cb2bnqCv5Kd/+9JLn7J+5+yt1PS/q1pKvL6yaAokXDb11fFz8jaZu7/6Lb9hHdHnaHpK3Fdw9AWXoz1DdV0puS3pN0Otv8kKQ56nrL75LaJf00+3Iw9FwM9QEl6+1QXzT8RSL8QPl6G37O8AMSRfiBRBF+IFGEH0gU4QcSRfiBRBF+IFGEH0gU4QcSRfiBRBF+IFGEH0gU4QcSRfiBRDV6ie4DknZ2u39xtq0ZNWvfmrVfEn2rVZF9+/vePrCh1/N/a+dmm5p1br9m7Vuz9kuib7Wqqm+87QcSRfiBRFUd/qUV7z+kWfvWrP2S6FutKulbpZ/5AVSn6iM/gIpUEn4zu9nM/mJmH5rZg1X0IY+ZtZvZe2b2btVLjGXLoO0zs63dtg0zs7Vmtj372eMyaRX1bZGZ7cleu3fN7JaK+jbKzP7LzD4ws/fN7P5se6WvXaBflbxuDX/bb2Ytkv4qabqk3ZI2Sprj7h80tCM5zKxdUpu7Vz4mbGbXSzom6Vl3n5Rt+3dJh9z9iew/zgvd/d+apG+LJB2reuXmbEGZEd1XlpZ0u6R/VYWvXaBfs1XB61bFkf9qSR+6+w53/5uk30q6rYJ+ND13f0PSNxe4v03S8uz2cnX98jRcTt+agrt3uPvm7PZRSWdWlq70tQv0qxJVhH+kpF3d7u9Wcy357ZL+aGbvmNm8qjvTg9ZuKyPtldRaZWd6EF25uZG+sbJ007x2tax4XTS+8Pu2qe5+paQZku7N3t42Je/6zNZMwzW9Wrm5UXpYWforVb52ta54XbQqwr9H0qhu97+XbWsK7r4n+7lP0io13+rDn5xZJDX7ua/i/nylmVZu7mllaTXBa9dMK15XEf6Nksab2RgzO0/SjyWtrqAf32Jmg7IvYmRmgyT9SM23+vBqSXOz23MlvVxhX76mWVZuzltZWhW/dk234rW7N/yPpFvU9Y3//0p6uIo+5PTrMkn/nf15v+q+SXpeXW8Dv1TXdyP3SLpI0npJ2yWtkzSsifr2nLpWc96irqCNqKhvU9X1ln6LpHezP7dU/doF+lXJ68YZfkCi+MIPSBThBxJF+IFEEX4gUYQfSBThBxJF+IFEEX4gUf8HwAJN/e+uiLMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 字符X的样例图片\n",
    "plt.imshow(train_images[3], cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Label数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "关于label数据，标签数据数值范围是0-1, 0代表字符`O`, 1代表字符`X`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 1, 1, 1, 0, 1, 1, 1, 0], dtype=uint8)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 打印前10个label数据\n",
    "train_labels[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 归一化\n",
    "将样本数据原来`0-255`数值范围的整数格式`uint8`，转换为`0-1.0`的浮点数`float32`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造one-hot矩阵\n",
    "\n",
    "将原来的标签数组转换为one-hot矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_labels = to_categorical(train_labels)\n",
    "test_labels = to_categorical(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印一下转换后的one-hot矩阵的前10个样本，可以看到train_label被转换成了`Nx2`的结构， 第一列代表是不是字符`O`, 是的话对应的列就是`1`, 否则就是`0`. 同样,矩阵的第二列代表是否为字符`X`，是的话对应的列就是`1`, 否则就是`0`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0.],\n",
       "       [0., 1.],\n",
       "       [0., 1.],\n",
       "       [0., 1.],\n",
       "       [0., 1.],\n",
       "       [1., 0.],\n",
       "       [0., 1.],\n",
       "       [0., 1.],\n",
       "       [0., 1.],\n",
       "       [1., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_labels[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成Batch数据\n",
    "\n",
    "在用Tensorflow进行模型训练的时候，需要生成Batch数据，就是模型迭代一次权重所需要的样本。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def next_batch(train_images, train_labels, batch_size):\n",
    "    # 生成索引数组 [0, 1, 2, 3, 4. ...., N]\n",
    "    idxs = np.array([i for i in range(len(train_images))])\n",
    "    # 将原来的数组重新打乱顺序\n",
    "    np.random.shuffle(idxs)\n",
    "    # 取前batch_size个索引\n",
    "    top_idxs = idxs[:batch_size]\n",
    "    # 根据索引数组得到训练集的图像\n",
    "    return train_images[top_idxs], train_labels[top_idxs]\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 权重初始化\n",
    "\n",
    "在神经网络模型训练之前，需要给神经网络的权重设置一个初始值，一般采用随机生成的方式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weight_variable(shape):\n",
    "    '''\n",
    "    权重初始化  weight init\n",
    "    初始化为一个接近0的很小的正数  init to a small number close to zero\n",
    "    '''\n",
    "    initial = tf.truncated_normal(shape, stddev = 0.1)\n",
    "    return tf.Variable(initial)\n",
    "\n",
    "def bias_variable(shape):\n",
    "    '''Bias变量初始化'''\n",
    "    initial = tf.constant(0.1, shape = shape)\n",
    "    return tf.Variable(initial)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义神经网络的结构\n",
    "\n",
    "在设计神经网络结构的时候，一定要去查一下**K210当下是否支持相应的Layer与激活函数**.\n",
    "\n",
    "你可以在[kendryte/nncase](https://github.com/kendryte/nncase/)的README页面查看K210支持的层**Support Layer**.\n",
    "\n",
    "另外，对于神经网络来说,用到的层只有全连接层**FullyConnected**。\n",
    "\n",
    "K210所支持的激活函数如下\n",
    "* Sigmoid\n",
    "* Relu\n",
    "* Relu6\n",
    "* LeakyRelu\n",
    "* Softmax\n",
    "\n",
    "激活函数，在当前的这个模型里面，用到了`Relu` 还有`Softmax`. \n",
    "\n",
    "K210支持的卷积层，池化层等其他层，请查阅文档。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输入层的尺寸为 Nx28x28， 每一个输入样本都是一个28x28的矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输入的是28x28的图片 通道数为1\n",
    "input_size = [None, 28, 28]\n",
    "# Dropout的概率\n",
    "keep_prob = tf.placeholder(\"float\")\n",
    "# 输入数据\n",
    "x = tf.placeholder(tf.float32, input_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来需要对模型进行变换，数据由原来的`28x28`的二维数组，转换为`28x28x1`的三维数组。\n",
    "\n",
    "之所以这样做， 是跟后期模型转换的时候有关，后面的教程我们再谈。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_image = tf.reshape(x, [-1, 28, 28, 1])  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来是Flatten操作，将原来的28x28x1的结构变为1维向量。 与卷积神经网络不同， 神经网络并不会参考图像的2D空间特征，只是将其当做一个向量来处理。\n",
    "\n",
    "那为什么我们不在图像预处理的时候就转换成1维的数据呢？ \n",
    "\n",
    "嘻嘻,是为了最大限度的复用K210的提供的Demo例程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-14-061b67464529>:2: flatten (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.flatten instead.\n"
     ]
    }
   ],
   "source": [
    "# Flatten层\n",
    "x_flat = tf.layers.flatten(x_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来是神经网络的全连接层(Fully Connected Layer),作为第一层隐藏层Hidden Layer。神经网络除了输入层与输出层，中间的均称之为隐藏层。\n",
    "全连接层的运算，可以转换为**矩阵运算**, 写成$y = X*W + b$的形式，　其中`X`是上一层的输入, `W`是权重矩阵, `b`是偏移量`bias`.\n",
    "\n",
    "\n",
    "转变为矩阵运算之后可以大幅提高计算速度，在权值保存的时候也更方便。\n",
    "\n",
    "神经网络有各种各样的激活函数**Active Function**，这里我们使用`Relu`作为神经网络的激活函数."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "# 输入向量的维度\n",
    "IMAGE_RESHAPED = 784\n",
    "# 每层隐藏层神经元的个数\n",
    "N_HIDDEN = 128\n",
    "\n",
    "# 第一层 全连接层1\n",
    "W_fc1 = weight_variable([IMAGE_RESHAPED, N_HIDDEN])\n",
    "b_fc1 = bias_variable([N_HIDDEN])\n",
    "h_fc1 = tf.nn.relu(tf.matmul(x_flat, W_fc1) + b_fc1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`IMAGE_RESHAPED=784`是`28x28x1=784`的数据flatten之后的尺寸, `N_HIDDEN`是隐藏层神经网络的个数, 一共128个神经元，同时也对应着128个bias."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在模型训练的时候，还可以加上`Dropout`层，随机删除神经元，　网络表现力强，而且可以防止过拟合。\n",
    "\n",
    "K210并不支持`Dropout`层，但是训练的时候可以加。\n",
    "\n",
    "在模型训练结束之后，再将Dropout层删除。\n",
    "\n",
    "`keep_prob`代表每次迭代的时候神经元保留的概率."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-16-31fb27abee06>:2: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    }
   ],
   "source": [
    "# Dropout\n",
    "h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第二层隐藏层的结构跟第一层的很类似，上一层是`N_HIDDEN`个神经元, 当前的隐藏层同样也有`N_HIDDEN`个"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第二层 全连接层2\n",
    "W_fc2 = weight_variable([N_HIDDEN, N_HIDDEN])\n",
    "b_fc2 = bias_variable([N_HIDDEN])\n",
    "h_fc2 = tf.nn.relu(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)\n",
    "h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后一层是输出层，该层也是全连接层，上一层神经元的节点个数是`N_HIDDEN`个，当前层的神经元个数为2,分别代表`O`跟`X`.\n",
    "\n",
    "激活函数使用的是`Softmax`, 使输出结果正规化。令输出结果范围在0-1之间, 而且所有的输出层结果数值相加等于1，相当于每个输出代表是这个类的概率，　`y_nn`就是最终的输出。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第三层是输出层，　\n",
    "W_fc3 = weight_variable([N_HIDDEN, 2])\n",
    "b_fc3 = bias_variable([2])\n",
    "y_nn = tf.nn.softmax(tf.matmul(h_fc2_drop, W_fc3) + b_fc3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练与评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 真实的label数据 one_hot矩阵\n",
    "y_ref = tf.placeholder(\"float\", [None, 2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "选用交叉熵**Cross Entroy**作为损失函数**Lost Function**, 损失函数是权重优化的依据，根据损失函数的梯度来优化参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 交叉熵\n",
    "cross_entropy = -tf.reduce_sum(y_ref * tf.log(y_nn))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用`ADAM`优化器来迭代更新权重, 制定学习率为`0.0001`, 定义ADam优化器的目标是使得损失函数最小化\n",
    "\n",
    "这里需要注意的是优化器的选择并不受K210的约束, 只在模型训练，权值迭代的时候产生作用,并不属于网络结构中的一部分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 0.0001 # 学习率\n",
    "# 定义优化器为ADam\n",
    "train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy) #使用adam优化器来以0.0001的学习率来进行微调"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算准确率**Accuracy** `acc`, 作为评价模型的指标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 是否准确的布尔值列表\n",
    "correct_prediction = tf.equal(tf.argmax(y_nn,1), tf.argmax(y_ref,1)) #判断预测标签和实际标签是否匹配\n",
    "#　计算平均值即为准确率\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction,\"float\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tf.Saver的使用可以看这篇Blog : [Tensorflow 模型的保存,读取和冻结,执行](https://www.jarvis73.cn/2018/04/25/Tensorflow-Model-Save-Read/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "saver = tf.train.Saver()\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer()) #初始化变量\n",
    "merged = tf.summary.merge_all() \n",
    "# 制定检查点等模型训练日志存放在同级目录下的logs文件夹\n",
    "writer = tf.summary.FileWriter('logs',sess.graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重复训练500个Batch,　每隔100个Batch输出一次日志"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0, train_accuracy 0.593333\n",
      "step 100, train_accuracy 0.993333\n",
      "step 200, train_accuracy 0.986667\n",
      "step 300, train_accuracy 0.993333\n",
      "step 400, train_accuracy 0.993333\n",
      "step 500, train_accuracy 0.993333\n"
     ]
    }
   ],
   "source": [
    "for i in range(501): #开始训练模型，循环训练500次\n",
    "    batch = next_batch(train_images, train_labels, 150) #batch大小设置为50\n",
    "    if i % 100 == 0:\n",
    "        # 在测试的准确率的时候，需要设置keep_prob为1\n",
    "        train_accuracy = accuracy.eval(session = sess,\n",
    "                                       feed_dict = {x:batch[0], y_ref:batch[1], keep_prob:1.0})\n",
    "        print(\"step %d, train_accuracy %g\" %(i, train_accuracy))\n",
    "        # 保存检查点\n",
    "        saver.save(sess, 'model/nn_xo.ckpt')\n",
    "    # feed_dict中加入参数`keep_prob`控制`dropout`比例。\n",
    "    train_step.run(session = sess, feed_dict = {x:batch[0], y_ref:batch[1],\n",
    "                   keep_prob:0.7}) #神经元输出保持不变的概率 keep_prob 为0.5\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最终拿测试集来评价一下模型的准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from model/nn_xo.ckpt\n",
      "test accuracy 0.99125\n"
     ]
    }
   ],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.global_variables())\n",
    "saver.restore(sess, 'model/nn_xo.ckpt')\n",
    "# 查看测试集的准确率\n",
    "print( \"test accuracy %g\" % accuracy.eval(feed_dict={\n",
    "    x:test_images, y_ref:test_labels, keep_prob:1.0}))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型冻结"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在模型训练完成之后, 所有的网络的权重还有Bias都不再变更了。接下来我们需要通过**模型冻结 Freeze**\n",
    "模型冻结的其中一种方法是将原来的变量`tf.variable`转变为常量`tf.constant`,并将权重保存在模型文件里面。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过`eval()`函数可以将tf.Variable导出权重的ndarray, 如下所示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.0987657 , 0.09167684, 0.10189684, 0.09403343, 0.10781259,\n",
       "       0.10423934, 0.10627533, 0.10272716, 0.09875294, 0.09838467,\n",
       "       0.11374552, 0.10951015, 0.10813315, 0.11166847, 0.11739392,\n",
       "       0.09535608, 0.10067159, 0.09764701, 0.09580787, 0.10748729],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 只输出前20个Bias\n",
    "b_fc1.eval()[:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "获得权重矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取权重矩阵\n",
    "WFC1 = W_fc1.eval()\n",
    "BFC1 = b_fc1.eval()\n",
    "WFC2 = W_fc2.eval()\n",
    "BFC2 = b_fc2.eval()\n",
    "WFC3 = W_fc3.eval()\n",
    "BFC3 = b_fc3.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们需要创建一个新的图, 做如下调整：\n",
    "\n",
    "1. 将原来的变量改为常量，并命名\n",
    "2. 命名输入层跟输出层。　这里输入层命名为`input_node`, 输出层命名为`ouput_node`\n",
    "3. 删除一些层，例如Dropout层只在模型训练的时候有用，前向计算就没了用处。\n",
    "4. 改变输入层的结构，训练的时候输入层是针对样本数据的,在实际部署的时候结构为`[-1, 28, 28, 1]`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建一个Graph对象\n",
    "g = tf.Graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "with g.as_default():\n",
    "    # 定义输入节点\n",
    "    x_image = tf.placeholder(\"float\", shape=[None, 28, 28, 1], name=\"input_node\")\n",
    "    # 第一层隐藏层\n",
    "    WFC1 = tf.constant(WFC1, name=\"WFC1\")\n",
    "    BFC1 = tf.constant(BFC1, name=\"BFC1\")\n",
    "    FC1 = tf.layers.flatten(x_image)\n",
    "    FC1 = tf.nn.relu(tf.matmul(FC1, WFC1) + BFC1)\n",
    "    # 第二层隐藏层\n",
    "    WFC2 = tf.constant(WFC2, name=\"WFC2\")\n",
    "    BFC2 = tf.constant(BFC2, name=\"BFC2\")\n",
    "    FC2 = tf.nn.relu(tf.matmul(FC1, WFC2) + BFC2)\n",
    "    # 第三层输出层\n",
    "    WFC3 = tf.constant(WFC3, name=\"WFC3\")\n",
    "    BFC3 = tf.constant(BFC3, name=\"BFC3\")\n",
    "    # 定义输出节点\n",
    "    OUTPUT = tf.nn.softmax(tf.matmul(FC2, WFC3) + BFC3, name=\"output_node\")\n",
    "    \n",
    "    \n",
    "    sess = tf.Session()\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    graph_def = g.as_graph_def()\n",
    "    # 保存模型,存放在同级目录下\n",
    "    tf.train.write_graph(graph_def, \"./\", \"nn_xo.pb\", as_text=False)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
